{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bayesian Imputation for Missing Values in Discrete Covariates\n",
    "\n",
    "Missing data is a very widespread problem in practical applications, both in covariates ('explanatory variables') and outcomes.\n",
    "When performing Bayesian inference with MCMC, discrete missing values cannot be sampled directly by gradient-based methods such as Hamiltonian Monte Carlo (HMC), which require continuous parameters.\n",
    "One way around this problem is to construct a new model that enumerates (sums over) the possible values of the discrete variables and to do inference over that model instead; for a single discrete variable, this is a mixture model (see e.g. [Stan's user guide on Latent Discrete Parameters](https://mc-stan.org/docs/2_18/stan-users-guide/change-point-section.html)).\n",
    "Enumerating the discrete latent sites requires some manual math work that can get tedious for complex models.\n",
    "Inference by automatic enumeration of discrete variables is implemented in NumPyro and provides a convenient way of dealing with missing discrete data.\n"
   ]
  },
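  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of what enumeration computes, using notation introduced here for illustration: $\\theta$ stands for the continuous parameters and $\\pi = p(A = 1 \\mid \\theta)$ for the probability that a binary latent variable $A$ equals 1. Summing $A$ out of the joint density leaves a two-component mixture for the observed outcome $Y$ that involves only continuous parameters, so HMC can be applied to it:\n",
    "\n",
    "$$\n",
    "p(Y \\mid \\theta) = \\sum_{a \\in \\{0, 1\\}} p(A = a \\mid \\theta)\\, p(Y \\mid A = a, \\theta) = (1 - \\pi)\\, p(Y \\mid A = 0, \\theta) + \\pi\\, p(Y \\mid A = 1, \\theta)\n",
    "$$\n",
    "\n",
    "On the log scale this sum is a log-sum-exp over the enumerated values of $A$."
   ]
  },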
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro funsor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import inf\n",
    "\n",
    "from graphviz import Digraph\n",
    "\n",
    "from jax import numpy as jnp, random\n",
    "from jax.scipy.special import expit\n",
    "\n",
    "import numpyro\n",
    "from numpyro import distributions as dist, sample\n",
    "from numpyro.infer.hmc import NUTS\n",
    "from numpyro.infer.mcmc import MCMC\n",
    "\n",
    "simkeys = random.split(random.PRNGKey(0), 10)\n",
    "nsim = 5000\n",
    "mcmc_key = random.PRNGKey(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First we will simulate data with correlated binary covariates. The assumption is that we wish to estimate the parameters of some parametric model without bias (e.g. for inferring a causal effect). For several different missing data patterns we will see how to impute the missing values so that the resulting estimates are unbiased.\n",
    "\n",
    "The basic data structure is as follows. Z is a latent variable that gives rise to the marginal dependence between A and B, the observed covariates. We will consider different missing data mechanisms for variable A, while B and Y are fully observed. The effects of A and B on Y are the effects of interest."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.43.0 (0)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"134pt\" height=\"188pt\"\n",
       " viewBox=\"0.00 0.00 134.00 188.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 184)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-184 130,-184 130,4 -4,4\"/>\n",
       "<!-- A -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>A</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"27\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">A</text>\n",
       "</g>\n",
       "<!-- Y -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>Y</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"63\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"63\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">Y</text>\n",
       "</g>\n",
       "<!-- A&#45;&gt;Y -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>A&#45;&gt;Y</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M35.35,-72.76C39.71,-64.28 45.15,-53.71 50.04,-44.2\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"53.23,-45.64 54.7,-35.15 47.01,-42.44 53.23,-45.64\"/>\n",
       "</g>\n",
       "<!-- B -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>B</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"99\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"99\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">B</text>\n",
       "</g>\n",
       "<!-- B&#45;&gt;Y -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>B&#45;&gt;Y</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M90.65,-72.76C86.29,-64.28 80.85,-53.71 75.96,-44.2\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"78.99,-42.44 71.3,-35.15 72.77,-45.64 78.99,-42.44\"/>\n",
       "</g>\n",
       "<!-- Z -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>Z</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"63\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"63\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">Z</text>\n",
       "</g>\n",
       "<!-- Z&#45;&gt;A -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>Z&#45;&gt;A</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M54.65,-144.76C50.29,-136.28 44.85,-125.71 39.96,-116.2\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"42.99,-114.44 35.3,-107.15 36.77,-117.64 42.99,-114.44\"/>\n",
       "</g>\n",
       "<!-- Z&#45;&gt;B -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>Z&#45;&gt;B</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M71.35,-144.76C75.71,-136.28 81.15,-125.71 86.04,-116.2\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"89.23,-117.64 90.7,-107.15 83.01,-114.44 89.23,-117.64\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.dot.Digraph at 0x7fdd0821b280>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dot = Digraph()\n",
    "dot.node(\"A\")\n",
    "dot.node(\"B\")\n",
    "dot.node(\"Z\")\n",
    "dot.node(\"Y\")\n",
    "dot.edges([\"ZA\", \"ZB\", \"AY\", \"BY\"])\n",
    "dot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "b_A = 0.25\n",
    "b_B = 0.25\n",
    "s_Y = 0.25\n",
    "Z = random.normal(simkeys[0], (nsim,))\n",
    "A = random.bernoulli(simkeys[1], expit(Z))\n",
    "B = random.bernoulli(simkeys[2], expit(Z))\n",
    "Y = A * b_A + B * b_B + s_Y * random.normal(simkeys[3], (nsim,))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MAR conditional on outcome\n",
    "\n",
    "According to Rubin's classic definitions there are three distinct types of missing data mechanisms:\n",
    "\n",
    "1. missing completely at random (MCAR)\n",
    "2. missing at random, conditional on observed data (MAR)\n",
    "3. missing not at random, even after conditioning on observed data (MNAR)\n",
    "\n",
    "Missing data mechanisms 1. and 2. are 'easy' to handle as they depend on observed data only.\n",
    "Mechanism 3. (MNAR) is trickier as it depends on data that is not observed, but may still be relevant to the outcome you are modeling (see below for a concrete example).\n",
    "\n",
    "First we will generate missing values in A, conditional on the value of Y (thus it is a MAR mechanism)."
   ]
  },
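  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Stated more formally, as a sketch in notation that is introduced here for illustration: let $M$ be the missingness indicator, $D_{\\mathrm{obs}}$ the observed part of the data and $D_{\\mathrm{mis}}$ the missing part. The three mechanisms then correspond to:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\text{MCAR:} &\\quad p(M \\mid D_{\\mathrm{obs}}, D_{\\mathrm{mis}}) = p(M) \\\\\n",
    "\\text{MAR:} &\\quad p(M \\mid D_{\\mathrm{obs}}, D_{\\mathrm{mis}}) = p(M \\mid D_{\\mathrm{obs}}) \\\\\n",
    "\\text{MNAR:} &\\quad p(M \\mid D_{\\mathrm{obs}}, D_{\\mathrm{mis}}) \\text{ depends on } D_{\\mathrm{mis}} \\text{ even given } D_{\\mathrm{obs}}\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "For the mechanism simulated first in this notebook, $M$ is the missingness of A and $p(M \\mid A, B, Y) = p(M \\mid Y)$, with $Y$ fully observed."
   ]
  },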
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.43.0 (0)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"170pt\" height=\"188pt\"\n",
       " viewBox=\"0.00 0.00 170.00 188.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 184)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-184 166,-184 166,4 -4,4\"/>\n",
       "<!-- Y -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>Y</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"63\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"63\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">Y</text>\n",
       "</g>\n",
       "<!-- M -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>M</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"135\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"135\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">M</text>\n",
       "</g>\n",
       "<!-- Y&#45;&gt;M -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>Y&#45;&gt;M</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M90,-18C92.61,-18 95.23,-18 97.84,-18\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"97.93,-21.5 107.93,-18 97.93,-14.5 97.93,-21.5\"/>\n",
       "</g>\n",
       "<!-- A -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>A</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"27\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">A</text>\n",
       "</g>\n",
       "<!-- A&#45;&gt;Y -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>A&#45;&gt;Y</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M35.35,-72.76C39.71,-64.28 45.15,-53.71 50.04,-44.2\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"53.23,-45.64 54.7,-35.15 47.01,-42.44 53.23,-45.64\"/>\n",
       "</g>\n",
       "<!-- B -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>B</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"99\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"99\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">B</text>\n",
       "</g>\n",
       "<!-- B&#45;&gt;Y -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>B&#45;&gt;Y</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M90.65,-72.76C86.29,-64.28 80.85,-53.71 75.96,-44.2\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"78.99,-42.44 71.3,-35.15 72.77,-45.64 78.99,-42.44\"/>\n",
       "</g>\n",
       "<!-- Z -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>Z</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"63\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"63\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">Z</text>\n",
       "</g>\n",
       "<!-- Z&#45;&gt;A -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>Z&#45;&gt;A</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M54.65,-144.76C50.29,-136.28 44.85,-125.71 39.96,-116.2\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"42.99,-114.44 35.3,-107.15 36.77,-117.64 42.99,-114.44\"/>\n",
       "</g>\n",
       "<!-- Z&#45;&gt;B -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>Z&#45;&gt;B</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M71.35,-144.76C75.71,-136.28 81.15,-125.71 86.04,-116.2\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"89.23,-117.64 90.7,-107.15 83.01,-114.44 89.23,-117.64\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.dot.Digraph at 0x7fdd0a4b0a90>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dot_mnar_y = Digraph()\n",
    "with dot_mnar_y.subgraph() as s:\n",
    "    s.attr(rank=\"same\")\n",
    "    s.node(\"Y\")\n",
    "    s.node(\"M\")\n",
    "dot_mnar_y.node(\"A\")\n",
    "dot_mnar_y.node(\"B\")\n",
    "dot_mnar_y.node(\"Z\")\n",
    "dot_mnar_y.node(\"M\")\n",
    "dot_mnar_y.edges([\"YM\", \"ZA\", \"ZB\", \"AY\", \"BY\"])\n",
    "dot_mnar_y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This graph depicts the data-generating mechanism, where Y is the only cause of missingness in A, denoted M. This means that the missingness in A is random, conditional on Y.\n",
    "\n",
    "As an example consider this simplified scenario:\n",
    "\n",
    "- A represents a history of heart illness\n",
    "- B represents the age of a patient\n",
    "- Y represents whether or not the patient will visit the general practitioner\n",
    "\n",
    "A general practitioner wants to find out why patients who are assigned to her clinic do or do not visit it. She thinks that a history of heart illness and age are potential causes of doctor visits. Data on patient ages are available through their registration forms, but information on prior heart illness may be available only after patients have visited the clinic. This makes the missingness in A (history of heart disease) dependent on the outcome (visiting the clinic)."
   ]
  },
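  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A is more likely to be observed for larger Y (MAR: depends only on the fully observed Y)\n",
    "A_isobs = random.bernoulli(simkeys[4], expit(3 * (Y - Y.mean())))\n",
    "# mark missing values of A with the sentinel -1\n",
    "Aobs = jnp.where(A_isobs, A, -1)\n",
    "A_obsidx = jnp.where(A_isobs)\n",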
    "\n",
    "# generate complete case arrays\n",
    "Acc = Aobs[A_obsidx]\n",
    "Bcc = B[A_obsidx]\n",
    "Ycc = Y[A_obsidx]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will evaluate 2 approaches:\n",
    "\n",
    "1. complete case analysis (which will lead to biased inferences)\n",
    "2. imputation of A (conditional on B)\n",
    "\n",
    "Note that explicitly including Y in the imputation model for A is unnecessary.\n",
    "The sampled imputations for A will condition on Y indirectly as the likelihood of Y is conditional on A.\n",
    "So values of A that give high likelihood to Y will be sampled more often than other values."
   ]
  },
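  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reason can be sketched with Bayes' rule (model parameters are left implicit here to keep the notation light). For a row where A is missing, the conditional distribution of A given the observed B and Y is proportional to the product of the imputation model and the outcome likelihood:\n",
    "\n",
    "$$\n",
    "p(A = a \\mid B, Y) = \\frac{p(A = a \\mid B)\\, p(Y \\mid A = a, B)}{\\sum_{a'} p(A = a' \\mid B)\\, p(Y \\mid A = a', B)}\n",
    "$$\n",
    "\n",
    "An imputation model for $p(A \\mid B)$ combined with the outcome model therefore already yields imputations that depend on Y."
   ]
  },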
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ccmodel(A, B, Y):\n",
    "    ntotal = A.shape[0]\n",
    "    # get parameters of outcome model\n",
    "    b_A = sample(\"b_A\", dist.Normal(0, 2.5))\n",
    "    b_B = sample(\"b_B\", dist.Normal(0, 2.5))\n",
    "    s_Y = sample(\"s_Y\", dist.HalfCauchy(2.5))\n",
    "\n",
    "    with numpyro.plate(\"obs\", ntotal):\n",
    "        ### outcome model\n",
    "        eta_Y = b_A * A + b_B * B\n",
    "        sample(\"obs_Y\", dist.Normal(eta_Y, s_Y), obs=Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "sample: 100%|██████████| 1000/1000 [00:02<00:00, 348.50it/s, 3 steps of size 4.27e-01. acc. prob=0.94]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "                mean       std    median      5.0%     95.0%     n_eff     r_hat\n",
      "       b_A      0.30      0.01      0.30      0.29      0.31    500.83      1.00\n",
      "       b_B      0.28      0.01      0.28      0.27      0.29    546.34      1.00\n",
      "       s_Y      0.25      0.00      0.25      0.24      0.25    559.55      1.00\n",
      "\n",
      "Number of divergences: 0\n"
     ]
    }
   ],
   "source": [
    "cckernel = NUTS(ccmodel)\n",
    "ccmcmc = MCMC(cckernel, num_warmup=250, num_samples=750)\n",
    "ccmcmc.run(mcmc_key, Acc, Bcc, Ycc)\n",
    "ccmcmc.print_summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def impmodel(A, B, Y):\n",
    "    ntotal = A.shape[0]\n",
    "    A_isobs = A >= 0\n",
    "\n",
    "    # get parameters of imputation model\n",
    "    mu_A = sample(\"mu_A\", dist.Normal(0, 2.5))\n",
    "    b_B_A = sample(\"b_B_A\", dist.Normal(0, 2.5))\n",
    "\n",
    "    # get parameters of outcome model\n",
    "    b_A = sample(\"b_A\", dist.Normal(0, 2.5))\n",
    "    b_B = sample(\"b_B\", dist.Normal(0, 2.5))\n",
    "    s_Y = sample(\"s_Y\", dist.HalfCauchy(2.5))\n",
    "\n",
    "    with numpyro.plate(\"obs\", ntotal):\n",
    "        ### imputation model\n",
    "        # get linear predictor for missing values\n",
    "        eta_A = mu_A + B * b_B_A\n",
    "\n",
    "        # sample imputation values for A\n",
    "        # mask out to not add log_prob to total likelihood right now\n",
    "        Aimp = sample(\n",
    "            \"A\",\n",
    "            dist.Bernoulli(logits=eta_A).mask(False),\n",
    "            infer={\"enumerate\": \"parallel\"},\n",
    "        )\n",
    "\n",
    "        # 'manually' calculate the log_prob\n",
    "        log_prob = dist.Bernoulli(logits=eta_A).log_prob(Aimp)\n",
    "\n",
    "        # cancel out enumerated values that are not equal to observed values\n",
    "        log_prob = jnp.where(A_isobs & (Aimp != A), -inf, log_prob)\n",
    "\n",
    "        # add to total likelihood for sampler\n",
    "        numpyro.factor(\"A_obs\", log_prob)\n",
    "\n",
    "        ### outcome model\n",
    "        eta_Y = b_A * Aimp + b_B * B\n",
    "        sample(\"obs_Y\", dist.Normal(eta_Y, s_Y), obs=Y)"
   ]
  },
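  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of what `impmodel` contributes to the joint density once A is enumerated (the row index $i$ and the indicator $R_i$, which stands for `A_isobs`, are notation introduced here for illustration): a row with observed A keeps only the term for its observed value, while a row with missing A sums over both values.\n",
    "\n",
    "$$\n",
    "L_i = \\begin{cases}\n",
    "p(A_i \\mid B_i)\\, p(Y_i \\mid A_i, B_i) & \\text{if } R_i = 1 \\\\\n",
    "\\sum_{a \\in \\{0, 1\\}} p(A_i = a \\mid B_i)\\, p(Y_i \\mid A_i = a, B_i) & \\text{if } R_i = 0\n",
    "\\end{cases}\n",
    "$$\n",
    "\n",
    "Setting `log_prob` to `-inf` where the enumerated value disagrees with an observed value is what removes the other term from the sum for observed rows."
   ]
  },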
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "sample: 100%|██████████| 1000/1000 [00:05<00:00, 174.83it/s, 7 steps of size 4.41e-01. acc. prob=0.91]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "                mean       std    median      5.0%     95.0%     n_eff     r_hat\n",
      "       b_A      0.25      0.01      0.25      0.24      0.27    447.79      1.01\n",
      "       b_B      0.25      0.01      0.25      0.24      0.26    570.66      1.01\n",
      "     b_B_A      0.74      0.08      0.74      0.60      0.86    316.36      1.00\n",
      "      mu_A     -0.39      0.06     -0.39     -0.48     -0.29    290.86      1.00\n",
      "       s_Y      0.25      0.00      0.25      0.25      0.25    527.97      1.00\n",
      "\n",
      "Number of divergences: 0\n"
     ]
    }
   ],
   "source": [
    "impkernel = NUTS(impmodel)\n",
    "impmcmc = MCMC(impkernel, num_warmup=250, num_samples=750)\n",
    "impmcmc.run(mcmc_key, Aobs, B, Y)\n",
    "impmcmc.print_summary()"
   ]
  },
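  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, when data are missing conditionally on Y, complete case analysis overestimates both coefficients (posterior means of 0.30 for b_A and 0.28 for b_B, against simulated values of 0.25), while imputation leads to consistent estimation of the parameters of interest (b_A and b_B)."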
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MNAR conditional on covariate\n",
    "\n",
    "When data are missing conditional on unobserved data, things get trickier.\n",
    "Here we will generate missing values in A, conditional on the value of A itself. This is missing not at random (MNAR): the missingness would be random only conditional on A, which is itself unobserved whenever it is missing.\n",
    "\n",
    "As an example consider patients who have cancer:\n",
    "\n",
    "- A represents weight loss\n",
    "- B represents age\n",
    "- Y represents overall survival time\n",
    "\n",
    "Both A and B can be related to survival time in cancer patients. For patients who have extreme weight loss, it is more likely that this will be noted by the doctor and registered in the electronic health record. For patients with no weight loss or little weight loss, it may be that the doctor forgets to ask about it and therefore does not register it in the records."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.43.0 (0)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"134pt\" height=\"188pt\"\n",
       " viewBox=\"0.00 0.00 134.00 188.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 184)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-184 130,-184 130,4 -4,4\"/>\n",
       "<!-- B -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>B</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"27\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">B</text>\n",
       "</g>\n",
       "<!-- Y -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>Y</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"27\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">Y</text>\n",
       "</g>\n",
       "<!-- B&#45;&gt;Y -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>B&#45;&gt;Y</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M27,-71.7C27,-63.98 27,-54.71 27,-46.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"30.5,-46.1 27,-36.1 23.5,-46.1 30.5,-46.1\"/>\n",
       "</g>\n",
       "<!-- Z -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>Z</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"63\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"63\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">Z</text>\n",
       "</g>\n",
       "<!-- Z&#45;&gt;B -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>Z&#45;&gt;B</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M54.65,-144.76C50.29,-136.28 44.85,-125.71 39.96,-116.2\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"42.99,-114.44 35.3,-107.15 36.77,-117.64 42.99,-114.44\"/>\n",
       "</g>\n",
       "<!-- A -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>A</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"99\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"99\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">A</text>\n",
       "</g>\n",
       "<!-- Z&#45;&gt;A -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>Z&#45;&gt;A</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M71.35,-144.76C75.71,-136.28 81.15,-125.71 86.04,-116.2\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"89.23,-117.64 90.7,-107.15 83.01,-114.44 89.23,-117.64\"/>\n",
       "</g>\n",
       "<!-- A&#45;&gt;Y -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>A&#45;&gt;Y</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M84.43,-74.83C74.25,-64.94 60.48,-51.55 48.97,-40.36\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"51.41,-37.85 41.8,-33.38 46.53,-42.87 51.41,-37.85\"/>\n",
       "</g>\n",
       "<!-- M -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>M</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"99\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"99\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">M</text>\n",
       "</g>\n",
       "<!-- A&#45;&gt;M -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>A&#45;&gt;M</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M99,-71.7C99,-63.98 99,-54.71 99,-46.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"102.5,-46.1 99,-36.1 95.5,-46.1 102.5,-46.1\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.dot.Digraph at 0x7fdcf828eb20>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dot_mnar_x = Digraph()\n",
    "with dot_mnar_y.subgraph() as s:\n",
    "    s.attr(rank=\"same\")\n",
    "    s.node(\"A\")\n",
    "    s.node(\"M\")\n",
    "dot_mnar_x.node(\"B\")\n",
    "dot_mnar_x.node(\"Z\")\n",
    "dot_mnar_x.node(\"Y\")\n",
    "dot_mnar_x.edges([\"AM\", \"ZA\", \"ZB\", \"AY\", \"BY\"])\n",
    "dot_mnar_x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A is observed with probability 0.9 when A = 0 and 0.1 when A = 1\n",
    "A_isobs = random.bernoulli(simkeys[5], 0.9 - 0.8 * A)\n",
    "Aobs = jnp.where(A_isobs, A, -1)\n",
    "A_obsidx = jnp.where(A_isobs)\n",
    "\n",
    "# generate complete case arrays\n",
    "Acc = Aobs[A_obsidx]\n",
    "Bcc = B[A_obsidx]\n",
    "Ycc = Y[A_obsidx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "sample: 100%|██████████| 1000/1000 [00:02<00:00, 342.07it/s, 3 steps of size 5.97e-01. acc. prob=0.92]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "                mean       std    median      5.0%     95.0%     n_eff     r_hat\n",
      "       b_A      0.27      0.02      0.26      0.24      0.29    667.08      1.01\n",
      "       b_B      0.25      0.01      0.25      0.24      0.26    811.49      1.00\n",
      "       s_Y      0.25      0.00      0.25      0.24      0.25    547.51      1.00\n",
      "\n",
      "Number of divergences: 0\n"
     ]
    }
   ],
   "source": [
    "cckernel = NUTS(ccmodel)\n",
    "ccmcmc = MCMC(cckernel, num_warmup=250, num_samples=750)\n",
    "ccmcmc.run(mcmc_key, Acc, Bcc, Ycc)\n",
    "ccmcmc.print_summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "sample: 100%|██████████| 1000/1000 [00:06<00:00, 166.36it/s, 7 steps of size 4.10e-01. acc. prob=0.94]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "                mean       std    median      5.0%     95.0%     n_eff     r_hat\n",
      "       b_A      0.34      0.01      0.34      0.32      0.35    576.15      1.00\n",
      "       b_B      0.33      0.01      0.33      0.32      0.34    800.58      1.00\n",
      "     b_B_A      0.32      0.12      0.32      0.12      0.51    342.21      1.01\n",
      "      mu_A     -1.81      0.09     -1.81     -1.95     -1.67    288.57      1.00\n",
      "       s_Y      0.26      0.00      0.26      0.25      0.26    820.20      1.00\n",
      "\n",
      "Number of divergences: 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "impkernel = NUTS(impmodel)\n",
    "impmcmc = MCMC(impkernel, num_warmup=250, num_samples=750)\n",
    "impmcmc.run(mcmc_key, Aobs, B, Y)\n",
    "impmcmc.print_summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Perhaps surprisingly, imputing missing values when the missingness mechanism depends on the variable itself will actually lead to bias, while complete case analysis is unbiased!\n",
    "See e.g. [Bias and efficiency of multiple imputation compared with complete‐case analysis for missing covariate values](https://doi.org/10.1002/sim.3944).\n",
    "\n",
    "However, complete case analysis may be undesirable as well, for example because discarding incomplete rows lowers the precision of the estimated effect of B on Y, or because an interaction is expected between the value of A and the effect of A on Y. To deal with this situation, an explicit model for the reason of missingness (or observation) is required. We will add one below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def impmissmodel(A, B, Y):\n",
    "    ntotal = A.shape[0]\n",
    "    A_isobs = A >= 0\n",
    "\n",
    "    # get parameters of imputation model\n",
    "    mu_A = sample(\"mu_A\", dist.Normal(0, 2.5))\n",
    "    b_B_A = sample(\"b_B_A\", dist.Normal(0, 2.5))\n",
    "\n",
    "    # get parameters of outcome model\n",
    "    b_A = sample(\"b_A\", dist.Normal(0, 2.5))\n",
    "    b_B = sample(\"b_B\", dist.Normal(0, 2.5))\n",
    "    s_Y = sample(\"s_Y\", dist.HalfCauchy(2.5))\n",
    "\n",
    "    # get parameter of model of missingness\n",
    "    with numpyro.plate(\"obsmodel\", 2):\n",
    "        p_Aobs = sample(\"p_Aobs\", dist.Beta(1, 1))\n",
    "\n",
    "    with numpyro.plate(\"obs\", ntotal):\n",
    "        ### imputation model\n",
    "        # get linear predictor for missing values\n",
    "        eta_A = mu_A + B * b_B_A\n",
    "\n",
    "        # sample imputation values for A\n",
    "        # mask out to not add log_prob to total likelihood right now\n",
    "        Aimp = sample(\n",
    "            \"A\",\n",
    "            dist.Bernoulli(logits=eta_A).mask(False),\n",
    "            infer={\"enumerate\": \"parallel\"},\n",
    "        )\n",
    "\n",
    "        # 'manually' calculate the log_prob\n",
    "        log_prob = dist.Bernoulli(logits=eta_A).log_prob(Aimp)\n",
    "\n",
    "        # cancel out enumerated values that are not equal to observed values\n",
    "        log_prob = jnp.where(A_isobs & (Aimp != A), -inf, log_prob)\n",
    "\n",
    "        # add to total likelihood for sampler\n",
    "        numpyro.factor(\"obs_A\", log_prob)\n",
    "\n",
    "        ### outcome model\n",
    "        eta_Y = b_A * Aimp + b_B * B\n",
    "        sample(\"obs_Y\", dist.Normal(eta_Y, s_Y), obs=Y)\n",
    "\n",
    "        ### missingness / observationmodel\n",
    "        eta_Aobs = jnp.where(Aimp, p_Aobs[0], p_Aobs[1])\n",
    "        sample(\"obs_Aobs\", dist.Bernoulli(probs=eta_Aobs), obs=A_isobs)"
   ]
  },
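  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sketch of the per-row likelihood that `impmissmodel` targets, with $R_i$ denoting `A_isobs` and $q_a = P(R_i = 1 \\mid A_i = a)$. Both symbols are introduced here for illustration; in the code, $q_1$ is `p_Aobs[0]` and $q_0$ is `p_Aobs[1]`.\n",
    "\n",
    "$$\n",
    "L_i = \\begin{cases}\n",
    "p(A_i \\mid B_i)\\, p(Y_i \\mid A_i, B_i)\\, q_{A_i} & \\text{if } R_i = 1 \\\\\n",
    "\\sum_{a \\in \\{0, 1\\}} p(A_i = a \\mid B_i)\\, p(Y_i \\mid A_i = a, B_i)\\, (1 - q_a) & \\text{if } R_i = 0\n",
    "\\end{cases}\n",
    "$$\n",
    "\n",
    "When $q_0 = q_1$ the missingness factor does not depend on $a$ and can be dropped, which recovers the likelihood of `impmodel`. When they differ, omitting the factor misweights the enumerated values of A."
   ]
  },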
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "sample: 100%|██████████| 1000/1000 [00:09<00:00, 106.81it/s, 7 steps of size 2.86e-01. acc. prob=0.91] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "                mean       std    median      5.0%     95.0%     n_eff     r_hat\n",
      "       b_A      0.26      0.01      0.26      0.24      0.27    267.57      1.00\n",
      "       b_B      0.25      0.01      0.25      0.24      0.26    537.10      1.00\n",
      "     b_B_A      0.74      0.07      0.74      0.62      0.84    421.54      1.00\n",
      "      mu_A     -0.45      0.08     -0.45     -0.58     -0.31    241.11      1.00\n",
      " p_Aobs[0]      0.10      0.01      0.10      0.09      0.11    451.90      1.00\n",
      " p_Aobs[1]      0.86      0.03      0.86      0.82      0.91    244.47      1.00\n",
      "       s_Y      0.25      0.00      0.25      0.24      0.25    375.51      1.00\n",
      "\n",
      "Number of divergences: 0\n"
     ]
    }
   ],
   "source": [
    "impmisskernel = NUTS(impmissmodel)\n",
    "impmissmcmc = MCMC(impmisskernel, num_warmup=250, num_samples=750)\n",
    "impmissmcmc.run(mcmc_key, Aobs, B, Y)\n",
    "impmissmcmc.print_summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now estimate the parameters b_A and b_B without bias, while still utilizing all observations.\n",
    "Modeling the missingness mechanism relies on assumptions that need to be either substantiated with prior evidence or examined through sensitivity analysis.\n",
    "\n",
    "For more reading on missing data in Bayesian inference, see:\n",
    "\n",
    "- [Presentation: Bayesian Methods for missing data (pdf)](https://www.bayes-pharma.org/Abstracts2013/slides/NickyBest_MissingData.pdf)\n",
    "- [Bayesian Approaches for Missing Not at Random Outcome Data: The Role of Identifying Restrictions (doi:10.1214/17-STS630)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6936760/)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
